
[PyTorch] Fusible ops preserve usages in quantized weight tensors #2929

Merged
timmoon10 merged 8 commits into NVIDIA:main from timmoon10:tmoon/debug-alternate-train-infer on May 1, 2026

Conversation

@timmoon10
Collaborator

@timmoon10 timmoon10 commented Apr 25, 2026

Description

We have experienced a correctness error in the grouped linear op when alternating between training and validation. During validation steps, we configure the weight quantizer without column-wise usage, which can leave quantized weight tensors with stale column-wise data in the next training step. This PR preserves quantizer usages in quantized weight tensors instead of disabling them, since a usage that was required in the past may be required again in the future.

Related: #2222
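
As a rough illustration (not the PR's actual test), the failure pattern is an alternating train/validation loop over a quantized-weight linear op. te_ops.Linear and fp8_autocast are existing TE APIs, but the sizes, device, and constructor arguments below are assumptions made for this sketch:

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

layer = te_ops.Linear(256, 256, device="cuda")
x = torch.randn(32, 256, device="cuda", requires_grad=True)

for step in range(4):
    # Training step: the backward pass needs column-wise (transposed) quantized
    # weight data for the dgrad GEMM.
    with te.fp8_autocast(enabled=True):
        y = layer(x)
    y.sum().backward()

    # Validation step: forward-only, so the weight quantizer is configured without
    # column-wise usage. Before this fix, that could leave stale column-wise data
    # in the cached quantized weight for the next training step.
    with torch.no_grad(), te.fp8_autocast(enabled=True):
        _ = layer(x)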

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Preserve quantizer usages in quantized weight tensors in the basic linear op
  • Preserve quantizer usages in quantized weight tensors in the grouped linear op
  • Add tests for the linear op that alternate between training and inference steps

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Quantized weight tensor may be used across steps, so removing a usage is not safe.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the bug (Something isn't working) and 2.15.0 labels Apr 25, 2026
Comment on lines -332 to -333
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
Collaborator Author

This comment was made incorrect in #1817.


@greptile-apps
Contributor

greptile-apps Bot commented Apr 25, 2026

Greptile Summary

This PR fixes a correctness bug where validation steps were stripping columnwise usage from quantized weight tensors, leaving them with stale/incomplete data when returning to training. The fix is applied symmetrically to both BasicLinear and GroupedLinear by changing reset_recipe_state to use OR-based usage preservation (never downgrade an existing columnwise=True) and updating pre_fuser_forward to pass columnwise=requires_grad instead of always False. New TestTrainingLoops tests exercise alternating train/infer sequences for both quantized and non-quantized weight configurations.

Confidence Score: 5/5

Safe to merge — all findings are P2 style/robustness notes; the bug-fix logic is sound.

No P0 or P1 issues found. The core fix (OR-based usage preservation in reset_recipe_state) is correct and well-tested by the new alternating train/infer test cases. Minor P2s: _linear_train_stage assumes module.bias is not None, and the single_grouped_weight vs non-single weight_is_quantized asymmetry warrants a comment, but neither affects correctness in any currently exercised code path.

grouped_linear.py — the single_grouped_weight weight_is_quantized detection path deserves a clarifying comment.

Important Files Changed

  • transformer_engine/pytorch/ops/basic/basic_linear.py (core fix): reset_recipe_state now preserves existing columnwise usage from the weight tensor's quantizer (OR logic), preventing inference steps from wiping column-wise data needed by subsequent training steps. pre_fuser_forward correctly changed from columnwise=False to columnwise=requires_grad.
  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Mirrors the basic_linear.py fix for grouped weights. Uses OR-based usage preservation. The asymmetry in the weight_is_quantized check for single_grouped_weight (weight.quantizer is not None) vs. non-single (is_quantized_tensor) is intentional but undocumented.
  • tests/pytorch/test_fusible_ops.py: New TestTrainingLoops class adds train↔infer alternation tests for te.ops.Linear. _linear_train_stage unconditionally accesses module.bias, which will crash for bias=False modules.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[reset_recipe_state called] --> B{weight_quantizer exists?}
    B -- No --> Z[Done]
    B -- Yes --> C[Set weight_quantizer.internal based on FP8 params / quantized weight flags]
    C --> D[Apply recipe-specific config force_pow_2_scales, amax_epsilon, etc.]
    D --> E{is weight a QuantizedTensor or has .quantizer set?}
    E -- No --> Z
    E -- Yes --> F[Read existing quantizer from weight tensor weight._quantizer or weight.quantizer]
    F --> G{Existing quantizer has usages?}
    G -- Yes --> H[OR-merge usages: rowwise = existing OR new, columnwise = existing OR new]
    G -- No --> I[Keep new quantizer usages as-is]
    H --> J[Update weight tensor quantizer with merged settings]
    I --> J
    J --> Z
    style H fill:#c8e6c9,stroke:#388e3c
    style G fill:#fff9c4,stroke:#f9a825
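
To make the OR-merge step concrete, here is a minimal self-contained sketch; Usages and merge_usages are hypothetical stand-ins for the quantizer usage flags handled in reset_recipe_state, not TE APIs.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Usages:
    rowwise: bool = True
    columnwise: bool = False

def merge_usages(new: Usages, existing: Optional[Usages]) -> Usages:
    # Never downgrade a usage the weight tensor's quantizer already has.
    if existing is None:
        return new
    return Usages(
        rowwise=new.rowwise or existing.rowwise,
        columnwise=new.columnwise or existing.columnwise,
    )

# An inference step requests row-wise only, but a previous training step already
# enabled column-wise, so the merged result keeps columnwise=True.
print(merge_usages(Usages(columnwise=False), Usages(columnwise=True)))
# -> Usages(rowwise=True, columnwise=True)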

Reviews (3): Last reviewed commit: "Blindly preserve quantizer usages in qua..."

Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch L1

Comment thread transformer_engine/pytorch/ops/basic/grouped_linear.py
Comment on lines +678 to +701
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and weight_is_quantized:
    # Get quantizer from weight tensor
    weight_tensor_quantizer = (
        weight.quantizer if self.single_grouped_weight else weight._quantizer
    )

    # Set quantizer usages
    # Note: Avoid disabling usages that are already set. The
    # weight tensor may be reused across steps, so future
    # steps may need usages that are currently unnecessary.
    weight_quantizer.set_usage(rowwise=True)
    columnwise_usage = torch.is_grad_enabled()
    if weight_tensor_quantizer is not None and weight_tensor_quantizer.columnwise_usage:
        columnwise_usage = True
    if columnwise_usage:
        weight_quantizer.set_usage(columnwise=True)

    # Update weight tensor
    if self.single_grouped_weight:
        if group_idx == 0:
            weight.quantizer = weight_quantizer.copy()
    else:
        weight.update_quantizer(weight_quantizer.copy())
Collaborator

@vthumbe1503 vthumbe1503 commented Apr 27, 2026

Do we really need this change of updating the quantizer in the quantized_model_init case?

In general, for the quantized_model_init case, I am not sure it even makes sense to modify the quantizer usages of the quantized tensor at any point in the lifecycle of a module after it has been created.

For example, say we create the module under the torch.no_grad context manager. columnwise_usage will be set to False, and if we then enter a training loop, we will be modifying the quantizer usage without modifying the parameters, leaving the quantizer and the quantized tensor in an inconsistent state.

In that case we could technically dequantize and quantize back to get both usages, but then we suffer dequantization errors (which might be OK?).

I am wondering whether we should touch the quantizer usages at all in the quantized_model_init case after the module is initialized (for the quantized model parameters)?
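
As a rough sketch of this scenario (the exact API combination is an assumption: fp8_model_init is the existing TE context manager for quantized parameter initialization, and whether the fusible ops honor it exactly this way is not verified here):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Parameters are quantized at construction time, inside torch.no_grad(), so the
# stored quantized weight may only carry row-wise data.
with torch.no_grad(), te.fp8_model_init(enabled=True):
    layer = te_ops.Linear(1024, 1024, device="cuda")

# A later training step needs column-wise (transposed) weight data for the dgrad
# GEMM; flipping only the quantizer's usage flags would leave the already-quantized
# parameter without that data, i.e. quantizer and tensor out of sync.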

Collaborator Author

@timmoon10 timmoon10 commented Apr 27, 2026

Hm, the quantizer semantics are not straightforward in this case. I agree it's a bad idea to put much effort into fixing incorrect usages, and better to error out instead of requantizing. However, I can imagine cases where the quantization recipe changes over time. Maybe for FP8 DS, you start off with a long amax history for stability, but later you shorten the history once the model has slightly converged. Or maybe for FP8 block-scaling, you enable/disable power-of-two scales at different points in training.

I think we should distinguish between the quantizer and the quantized tensor. Once you have a quantized tensor (from initialization, the optimizer step, FSDP, etc), you're stuck with the data inside of it. However, it seems reasonable to change the quantizer so that you can change future casts. This is a bit tricky, since there's no guarantee that the quantizer and the quantized tensor currently match.

Collaborator

The cases you mentioned are not something that would happen every iteration, right? I think it might be worth keeping the quantized tensor's data contents consistent with its metadata, like the quantizer and the other things you mentioned.

If future casts would need it anyway (most likely through dequantization), why not correct the data now? It would make debugging issues a lot smoother if the quantizer and quantized tensor are in sync with each other.

Collaborator Author

@timmoon10 timmoon10 commented Apr 27, 2026

We still assume that the recipe changes infrequently, and this logic only happens when the recipe changes.

We should make sure that the quantized tensor stores all the information needed to dequantize, and should not depend at all on the quantizer. In this case, there's no need to keep the quantizer and quantized tensor in sync, and in fact it limits us (e.g. if the quantized tensor is created by hand rather than from a quantizer). Also, in general we should not assume that quantizer configs are relevant for dequantization. For example, stochastic rounding affects quantization, but afterwards there is no functional difference between SR and non-SR data. Similarly, FP8 DS and FP8 CS have very different quantization, but dequantization is identical.

Comment on lines +411 to +425
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
    # Set quantizer usages
    # Note: Avoid disabling usages that are already set. The
    # weight tensor may be reused across steps, so future
    # steps may need usages that are currently unnecessary.
    weight_quantizer.set_usage(rowwise=True)
    columnwise_usage = torch.is_grad_enabled()
    if weight._quantizer is not None and weight._quantizer.columnwise_usage:
        columnwise_usage = True
    if columnwise_usage:
        weight_quantizer.set_usage(columnwise=True)

    # Update weight tensor
    weight.update_quantizer(weight_quantizer.copy())
Member

Why did you move this to its own section after the other quantizers? Logically, it should follow the rest of the setup for the weight quantizer.
Also, this logic is a little strange to me. Ultimately what it does is set the columnwise usage if grad is enabled and then keep it forever. So it should be something like this:

Suggested change

Original:
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
    # Set quantizer usages
    # Note: Avoid disabling usages that are already set. The
    # weight tensor may be reused across steps, so future
    # steps may need usages that are currently unnecessary.
    weight_quantizer.set_usage(rowwise=True)
    columnwise_usage = torch.is_grad_enabled()
    if weight._quantizer is not None and weight._quantizer.columnwise_usage:
        columnwise_usage = True
    if columnwise_usage:
        weight_quantizer.set_usage(columnwise=True)
    # Update weight tensor
    weight.update_quantizer(weight_quantizer.copy())

Suggested:
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
    # Set quantizer usages
    # Note: Avoid disabling usages that are already set. The
    # weight tensor may be reused across steps, so future
    # steps may need usages that are currently unnecessary.
    if weight._quantizer is None or (not weight._quantizer.columnwise_usage and torch.is_grad_enabled()):
        weight_quantizer.set_usage(rowwise=True, columnwise=True)
        weight.update_quantizer(weight_quantizer.copy())
(the assumption here is that the rowwise would always be there for the weights).

Member

Actually, no. I see that the context of this function is the resetting of the recipe, so sure, we need to create the new quantizer and pass it to the tensor. That makes sense. I don't see the requantization happening in this case though: if the weight is already a quantized tensor, where is the code that actually applies this change in the _quantizer to the data it holds?

Collaborator Author

The big problem with the existing code is that we update the weight param quantizer, and then we do more weight quantizer configuration afterwards. We should make sure the weight quantizer is fully configured, and only then update the quantized param. The logic for updating the quantized param was also a little convoluted, so I attempted to make it a bit more clear.

Collaborator Author

I don't think requantization really makes sense: #2929 (comment)

I can see the argument that we should blindly preserve usages in the weight param quantizer. This logic may be trying too hard to clean up after a user doing something weird.

Member

I don't agree with that (if I change the recipe to a different type, then trying to run without this requantization would just fail, since e.g. we would try to multiply an MXFP8 tensor with an FP8 CS tensor), but I think this should be addressed in its own PR rather than here.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 changed the title [PyTorch] Avoid removing usages from quantized weight tensors [PyTorch] Preserve usages in quantized weight tensors Apr 30, 2026
@timmoon10 timmoon10 changed the title [PyTorch] Preserve usages in quantized weight tensors [PyTorch] Fusbile ops preserve usages in quantized weight tensors Apr 30, 2026
@timmoon10 timmoon10 changed the title [PyTorch] Fusbile ops preserve usages in quantized weight tensors [PyTorch] Fusible ops preserve usages in quantized weight tensors Apr 30, 2026
@timmoon10
Collaborator Author

/te-ci pytorch L1

Collaborator

@vthumbe1503 vthumbe1503 left a comment

LGTM.

@timmoon10 timmoon10 merged commit 88e6071 into NVIDIA:main May 1, 2026
10 of 14 checks passed
KshitijLakhani pushed a commit that referenced this pull request May 1, 2026
[PyTorch] Fusible ops preserve usages in quantized weight tensors (#2929)

* Avoid removing usages from quantized weight in linear op

Quantized weight tensor may be used across steps, so removing a usage is not safe.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test to catch bug when alternating train and infer steps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid removing usages from quantized weights in grouped linear op

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore pre-forward quantizer config in ops

Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Blindly preserve quantizer usages in quantized weight params.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Labels

2.15.0, bug (Something isn't working)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants